adds available_device to test_precision_recall_curve #3335 #3368
base: master
Conversation
(otherwise a type error occurs on MPS)
assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
assert np.allclose(precision, sk_precision, rtol=1e-6)
Why this change?
From what I understand, pytest.approx may convert a float32 parameter into float64. This would break on MPS.
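A minimal sketch of the dtype concern, using hypothetical array values (only the casting behavior matters): scikit-learn returns float64 arrays, so an elementwise tolerance comparison such as np.allclose avoids any float32-vs-float64 mismatch that an exact-equality path might trip over.

```python
import numpy as np

# Hypothetical precision values standing in for the metric's float32 output.
precision = np.array([0.5, 0.6667, 1.0], dtype=np.float32)
# scikit-learn's precision_recall_curve returns float64 arrays;
# here we simulate that by upcasting.
sk_precision = precision.astype(np.float64)

# np.allclose compares elementwise within a relative tolerance,
# so the float32 vs. float64 dtype difference does not matter.
assert np.allclose(precision, sk_precision, rtol=1e-6)
```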
…ision, recall and thresholds
def to_numpy_float32(x):
    if isinstance(x, torch.Tensor):
        if x.device.type == "mps":
            x = x.to("cpu")  # Explicitly move from MPS to CPU
        return x.detach().to(dtype=torch.float32).numpy()
    elif isinstance(x, np.ndarray):
        return x.astype(np.float32)
    return x
Let's not use this pattern but use what we did previously
No description provided.